require(colorout)
require(sqldf)
require(arm)
require(rstan)
require(foreach)
require(doMC)
require(RPostgreSQL)

# Start from an empty workspace, then set the session-wide printing options
# (and the JVM flag) this script runs with.
rm(list = ls())
options(
  digits = 2,
  scipen = 9,
  width = 110,
  java.parameters = "-Xrs"
)

################################################################################################################
# helper functions

# Connect to the local PostgreSQL database named after the current OS user,
# authenticating as that same user.
username <- Sys.getenv("USER")
conn <- dbConnect(dbDriver("PostgreSQL"), user = username, password = "",
                  host = "localhost", port = 5432, dbname = username)

# Run a query over `conn` and return its result (dbGetQuery returns a
# data.frame).
SQL <- function(x) {
  dbGetQuery(conn, x)
}

# Run one or more SQL statements through the psql command-line client,
# against the same database/user as `conn`.
#   x - SQL text; may contain several `;`-separated statements.
# The SQL is passed through shQuote() so embedded single quotes (e.g. string
# literals) no longer break the shell command. A non-zero psql exit status
# stops the script instead of letting later steps run against a table that
# was never created. Returns the exit status invisibly.
SQL_EXECUTE <- function(x) {
  cmd <- paste("psql", "-d", shQuote(username), "-U", shQuote(username),
               "-c", shQuote(x))
  status <- system(cmd)
  if (status != 0) {
    stop("psql exited with status ", status, call. = FALSE)
  }
  invisible(status)
}

################################################################################################################
# load samples

# Restore the fitted Stan model `M` and pull out its posterior draws.
# NOTE(review): this RData presumably also restores xs_cat, cat_lookups,
# interactions, interaction_lookups, xs_slope, xs_num, xs_num_mu and xs_num_sd,
# none of which are defined in this chunk; confirm.
load("yg_pa_mrp_mcmc.RData")
samples <- rstan::extract(M)
# print(M)

# Pick a reproducible subset of posterior draws (without replacement) to
# score with; sample.int(N, n) draws the same indices as sample(1:N, n).
n_samples <- 100
set.seed(12345)
samples_ix <- sample.int(length(samples[["alpha"]]), n_samples)

################################################################################################################
# get every distinct combination of the category columns that actually occurs
# in poststrat -- these are the cells we need to score

# One row per distinct combination of the xs_cat columns present in
# poststrat; grouping is by select-list position.
full_cats <- SQL(local({
  cat_cols <- paste(xs_cat, collapse = ", ")
  group_positions <- paste(seq_along(xs_cat), collapse = ", ")
  paste0("select ", cat_cols, " from poststrat group by ", group_positions)
}))

################################################################################################################
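# compile, for every category combination and every selected posterior draw,
# the intercept (alpha) and the dem2way2008 slope (beta)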

# Map one batch of random effects from a single posterior draw onto the rows
# of full_cats.
#   coefs       - coefficient draws for the batch, one per level (presumably
#                 standardized / non-centered, given the sigma scaling -- confirm)
#   sigma       - the batch's own scale (a column of sigma_alpha / sigma_beta)
#   sigma_sigma - the shared hyper-scale (sigma_sigma_alpha / sigma_sigma_beta)
#   labels      - level labels, in the same order as coefs
#   keys        - the level label of each full_cats row
# Returns one numeric value per key. Keys with no matching label (unseen
# levels, NA categories) contribute 0.
lookup_effect <- function(coefs, sigma, sigma_sigma, labels, keys) {
  # keep the original left-to-right multiplication order
  effect <- coefs * sigma * sigma_sigma
  names(effect) <- labels
  effect <- as.numeric(effect[keys])
  effect[is.na(effect)] <- 0
  effect
}

registerDoMC(5)
ab <- foreach(i_sample = seq_len(n_samples)) %dopar% {
  message("Compiling alphas and betas, ", i_sample, " of ", n_samples)
  ix <- samples_ix[i_sample]

  # intercepts: global alpha + category main effects + pairwise interactions
  a <- rep(samples$alpha[ix], nrow(full_cats))
  for (i_term in seq_along(xs_cat)) {
    x <- xs_cat[i_term]
    a <- a + lookup_effect(
      samples[[paste0("alpha_", x)]][ix, ],
      samples[["sigma_alpha"]][ix, i_term],
      samples[["sigma_sigma_alpha"]][ix],
      cat_lookups[[x]],
      as.character(full_cats[, x])
    )
  }
  # interaction scales follow the main-effect scales in sigma_alpha;
  # seq_len() makes this loop a no-op when there are no interactions
  for (i_term in seq_len(nrow(interactions))) {
    x <- interactions$x[i_term]
    keys <- paste0(full_cats[, interactions$x1[i_term]], "__",
                   full_cats[, interactions$x2[i_term]])
    a <- a + lookup_effect(
      samples[[paste0("alpha_", x)]][ix, ],
      samples[["sigma_alpha"]][ix, length(xs_cat) + i_term],
      samples[["sigma_sigma_alpha"]][ix],
      interaction_lookups[[i_term]],
      keys
    )
  }

  # slopes on dem2way2008: global beta + category-specific effects
  b <- rep(samples$beta[ix], nrow(full_cats))
  for (i_term in seq_along(xs_slope)) {
    x <- xs_slope[i_term]
    b <- b + lookup_effect(
      samples[[paste0("beta_", x)]][ix, ],
      samples[["sigma_beta"]][ix, i_term],
      samples[["sigma_sigma_beta"]][ix],
      cat_lookups[[x]],
      as.character(full_cats[, x])
    )
  }

  list(a = a, b = b)
}

# Rows = full_cats rows, columns = selected draws. do.call(cbind, ...) always
# yields a matrix; sapply() collapsed to a vector when full_cats had one row.
alphas <- do.call(cbind, lapply(ab, function(draw) draw$a))
betas <- do.call(cbind, lapply(ab, function(draw) draw$b))
colnames(alphas) <- paste0("a", seq_len(n_samples))
colnames(betas) <- paste0("b", seq_len(n_samples))

# Coefficients on the standardized numeric predictors: one row per selected
# draw, one column per entry of xs_num. drop = FALSE keeps this a matrix when
# xs_num has a single entry (t(sapply(...)) produced a 1 x n_samples matrix).
z_betas <- unname(
  samples[["beta_z"]][samples_ix, seq_along(xs_num), drop = FALSE]
)

# upload to postgres
# Replace table `ab`: one row per category combination, carrying that row's
# per-draw intercepts (a1..aN) and slopes (b1..bN). The R row names are
# written as an extra column.
out <- data.frame(full_cats, alphas, betas, stringsAsFactors = FALSE)
SQL_EXECUTE("drop table if exists ab;")
dbWriteTable(conn, name = "ab", value = out, row.names = TRUE, append = FALSE)

################################################################################################################
# compute yhats

# pre-compute all X's to speed computation
# (Re)build table `xs` from poststrat: id/outcome columns, the category
# columns, dem2way2008 on the logit scale (clipped to [0.005, 0.995], i.e.
# +/-5.2933 ~= ln(199)), and each xs_num column centered on its stored mean
# and divided by twice its stored sd. NULLs are replaced by the stored mean.
# NOTE(review): assumes "dem2way2008" is not also listed in xs_num, otherwise
# the column would be emitted twice -- confirm.
SQL_EXECUTE(local({
  dem_value <- paste0("coalesce(dem2way2008, ", xs_num_mu["dem2way2008"], ")")
  dem_logit <- paste0(
    "case when ", dem_value, " < 0.005 then -5.2933 ",
    "when ", dem_value, " > 0.995 then 5.2933 ",
    "else (-1 * ln(1/", dem_value, " - 1)) end as dem2way2008"
  )
  scaled_nums <- paste0(
    "(coalesce(", xs_num, ", ", xs_num_mu[xs_num], ") - ", xs_num_mu[xs_num],
    ") / (2 * ", xs_num_sd[xs_num], ") as ", xs_num,
    collapse = ", "
  )
  paste0(
    "drop table if exists xs; ",
    "create table xs as select randomid, voted2012, fips, ",
    paste0(colnames(full_cats), collapse = ", "), ", ",
    dem_logit, ", ",
    scaled_nums,
    " from poststrat;"
  )
}))

# yhats, pre-correction
# (Re)build table `v1`: for every xs row, one linear-predictor column per
# selected draw i,
#   yhat_i = a_i + b_i * dem2way2008 + sum_k z_betas[i, k] * xs_num[k],
# rounded to 3 decimals, plus an empty `yhat` column filled in afterwards.
# a_i / b_i come from table `ab`, matched on every category column.
# NOTE(review): `=` never matches NULL, so xs rows with a NULL category value
# would be dropped by this inner join -- confirm poststrat has none.
# NOTE(review): selecting xs.state assumes "state" is one of the category
# columns (otherwise the qualifier is unnecessary) -- confirm.
SQL_EXECUTE(local({
  draw_cols <- vapply(seq_len(n_samples), function(i) {
    z_terms <- paste0(z_betas[i, ], " * ", xs_num, collapse = " + ")
    paste0("round(cast(a", i, " + b", i, " * dem2way2008 + ", z_terms,
           " as numeric), 3) as yhat", i)
  }, character(1))
  join_keys <- paste0("xs.", colnames(full_cats), "=ab.", colnames(full_cats),
                      collapse = " and ")
  paste0(
    "drop table if exists v1; ",
    "create table v1 as select randomid, xs.state, voted2012, fips, ",
    "cast(NULL as numeric) as yhat, ",
    paste(draw_cols, collapse = ", "),
    " from xs inner join ab on (", join_keys, ");"
  )
}))

# Fill v1.yhat with the mean of the per-draw columns yhat1..yhatN, rounded to
# 3 decimals (a NULL in any draw column makes the mean NULL).
SQL_EXECUTE(sprintf(
  "update v1 set yhat = round((%s) / %s, 3);",
  paste(paste0("yhat", seq_len(n_samples)), collapse = " + "),
  n_samples
))

